{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# Load the iris feature matrix and its class labels.\n",
    "iris = load_iris()\n",
    "X, y = iris.data, iris.target\n",
    "\n",
    "# Pin the seed so the train/test split -- and therefore every downstream\n",
    "# output that uses X_test -- is identical after Restart Kernel -> Run All.\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n",
    "\n",
    "# Fit a logistic-regression classifier with scikit-learn's default settings.\n",
    "clr = LogisticRegression()\n",
    "clr.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skl2onnx import convert_sklearn\n",
    "from skl2onnx.common.data_types import FloatTensorType\n",
    "\n",
    "# Declare the ONNX graph input: float rows of 4 features each. The first\n",
    "# dimension is None so the exported model accepts any batch size instead of\n",
    "# being locked to exactly one row.\n",
    "initial_type = [('float_input', FloatTensorType([None, 4]))]\n",
    "onx = convert_sklearn(clr, initial_types=initial_type)\n",
    "\n",
    "# Serialize the converted model to disk so it can be reloaded by file name.\n",
    "with open(\"logreg_iris.onnx\", \"wb\") as f:\n",
    "    f.write(onx.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "\n",
    "# Read the saved ONNX file back from disk and print it for inspection.\n",
    "model = onnx.load(\"logreg_iris.onnx\")\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnxruntime as rt\n",
    "import numpy\n",
    "\n",
    "# Open an inference session on the saved model and look up the names of its\n",
    "# first input and its first two outputs.\n",
    "sess = rt.InferenceSession(\"logreg_iris.onnx\")\n",
    "input_name = sess.get_inputs()[0].name\n",
    "label_name = sess.get_outputs()[0].name\n",
    "probability_name = sess.get_outputs()[1].name\n",
    "\n",
    "# Slice with [:1] instead of indexing with [0] so the sample stays 2-D with\n",
    "# shape (1, n_features). The model input was declared as a rank-2 float\n",
    "# tensor, and a bare 1-D row does not match that rank.\n",
    "pred_onx = sess.run([label_name, probability_name], {input_name: X_test[:1].astype(numpy.float32)})\n",
    "\n",
    "# print info\n",
    "print('input_name: ' + input_name)\n",
    "print('label_name: ' + label_name)\n",
    "print('probability_name: ' + probability_name)\n",
    "print(X_test[0])\n",
    "print(pred_onx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
